# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations

from collections import defaultdict
from contextlib import contextmanager

from sqlalchemy import text


def get_mssql_table_constraints(conn, table_name) -> dict[str, dict[str, list[str]]]:
    """
    Return the primary key, unique and foreign key constraints of a table with their columns.

    Some tables like `task_instance` are missing the primary key constraint
    name and the name is auto-generated by the SQL server, so this function
    helps to retrieve any primary key, unique or foreign key constraint name.

    :param conn: sql connection object
    :param table_name: table name
    :return: a nested mapping of ``{constraint type: {constraint name: [column names]}}``.
        The mapping is a ``defaultdict``, so looking up a constraint type or name that
        was not found yields an empty entry instead of raising ``KeyError``.
    """
    # The table name is passed as a bound parameter instead of being interpolated into
    # the statement, so names containing quotes cannot alter or break the query.
    query = text(
        """SELECT tc.CONSTRAINT_NAME , tc.CONSTRAINT_TYPE, ccu.COLUMN_NAME
     FROM INFORMATION_SCHEMA.TABLE_CONSTRAINTS AS tc
     JOIN INFORMATION_SCHEMA.CONSTRAINT_COLUMN_USAGE AS ccu ON ccu.CONSTRAINT_NAME = tc.CONSTRAINT_NAME
     WHERE tc.TABLE_NAME = :table_name AND
     (tc.CONSTRAINT_TYPE = 'PRIMARY KEY' or UPPER(tc.CONSTRAINT_TYPE) = 'UNIQUE'
     or UPPER(tc.CONSTRAINT_TYPE) = 'FOREIGN KEY')
    """
    ).bindparams(table_name=table_name)
    result = conn.execute(query).fetchall()
    # Group the flat (name, type, column) rows by constraint type, then by constraint name.
    constraint_dict = defaultdict(lambda: defaultdict(list))
    for constraint, constraint_type, col_name in result:
        constraint_dict[constraint_type][constraint].append(col_name)
    return constraint_dict


@contextmanager
def disable_sqlite_fkeys(op):
    """
    Turn off SQLite foreign key enforcement for the duration of the ``with`` block.

    On SQLite, ``PRAGMA foreign_keys=off`` is executed before the block and
    ``PRAGMA foreign_keys=on`` after it. On any other dialect nothing is executed.
    Enforcement is always switched back *on* afterwards; the setting that was in
    effect before entering the block is not recorded.

    :param op: object exposing ``get_bind()`` and ``execute()``
        (presumably the alembic operations object -- confirm against callers)
    :return: context manager yielding the same ``op`` object
    """
    if op.get_bind().dialect.name == "sqlite":
        op.execute("PRAGMA foreign_keys=off")
        try:
            yield op
        finally:
            # Re-enable enforcement even when the block raises, so a failed
            # migration step does not leave foreign keys disabled on the connection.
            op.execute("PRAGMA foreign_keys=on")
    else:
        yield op
